
import numpy as np

# Flag telling the rest of this module whether CuPy was imported successfully.
# It starts out True and is flipped to False if the import below fails.
gpu_enable = True
try:
    import cupy as cp
    cupy = cp
except ImportError:
    # A missing CuPy installation must not break importing this module;
    # only record that GPU support is unavailable. Note that the names
    # `cp` and `cupy` remain undefined in this case.
    gpu_enable = False

from dezero import Variable

def get_array_module(x):
    """Return the array module that matches the given data.

    Args:
        x: A ``Variable`` (its ``data`` attribute is inspected) or a raw
            array-like object.

    Returns:
        ``np`` when CuPy could not be imported; otherwise whatever
        ``cp.get_array_module`` reports for the data.
    """
    payload = x.data if isinstance(x, Variable) else x

    if gpu_enable:
        return cp.get_array_module(payload)
    return np

def as_numpy(x):
    """Convert the given data to a ``numpy.ndarray``.

    Args:
        x: A ``Variable`` (its ``data`` attribute is converted), a scalar,
            a ``numpy.ndarray``, a CuPy array, or another array-like object.

    Returns:
        A ``numpy.ndarray``. An input that already is a ``numpy.ndarray`` is
        returned as-is, without copying.
    """
    if isinstance(x, Variable):
        x = x.data

    if np.isscalar(x):
        return np.array(x)
    if isinstance(x, np.ndarray):
        return x

    if not gpu_enable:
        # The CuPy import failed, so `cp` is not bound and touching it would
        # raise NameError. Without CuPy the data cannot live on a GPU, so a
        # plain NumPy conversion is the right fallback (e.g. for lists).
        return np.asarray(x)

    return cp.asnumpy(x)

def as_cupy(x):
    """Convert the given data to a ``cupy.ndarray``.

    Args:
        x: A ``Variable`` (its ``data`` attribute is converted) or a raw
            array-like object.

    Returns:
        The result of ``cp.array`` applied to the data.

    Raises:
        Exception: If CuPy could not be imported.
    """
    payload = x.data if isinstance(x, Variable) else x

    if gpu_enable:
        return cp.array(payload)

    raise Exception('CuPy cannot be loaded. Please install CuPy!')